查看原文
其他

tensorflow Object Detection API使用预训练模型mask r-cnn实现对象检测

gloomyfish OpenCV学堂 2019-03-28

Mask R-CNN模型下载

Mask R-CNN是何凯明大神在2017年整出来的新网络模型,在原有的R-CNN基础上实现了区域ROI的像素级别分割。关于Mask R-CNN模型本身的介绍与解释网络上面已经是铺天盖地了,论文也是到处可以看到。这里主要想介绍一下在tensorflow中如何使用预训练的Mask R-CNN模型实现对象检测与像素级别的分割。tensorflow框架有个扩展模块叫做models里面包含了很多预训练的网络模型,提供给tensorflow开发者直接使用或者迁移学习使用,首先需要下载Mask R-CNN网络模型,这个在tensorflow的models的github上面有详细的解释与model zoo的页面介绍, tensorflow models的github主页地址如下: https://github.com/tensorflow/models

我这里下载的是:

mask_rcnn_inception_v2_coco_2018_01_28.tar.gz

下载好模型之后可以解压缩为tar文件,然后通过下面的代码读入模型

  1. MODEL_NAME = 'mask_rcnn_inception_v2_coco_2018_01_28'

  2. MODEL_FILE = 'D:/tensorflow/' + MODEL_NAME + '.tar'

  3. # Path to frozen detection graph

  4. PATH_TO_CKPT = MODEL_NAME + '/frozen_inference_graph.pb'

  5. # List of the strings that is used to add correct label for each box.

  6. PATH_TO_LABELS = os.path.join('D:/tensorflow/models/research/object_detection/data', 'mscoco_label_map.pbtxt')

  7. NUM_CLASSES = 90

  8. tar_file = tarfile.open(MODEL_FILE)

  9. for file in tar_file.getmembers():

  10.    file_name = os.path.basename(file.name)

  11.    if 'frozen_inference_graph.pb' in file_name:

  12.        tar_file.extract(file, os.getcwd())

  13. detection_graph = tf.Graph()

  14. with detection_graph.as_default():

  15.    od_graph_def = tf.GraphDef()

  16.    with tf.gfile.GFile(PATH_TO_CKPT, 'rb') as fid:

  17.        serialized_graph = fid.read()

  18.        od_graph_def.ParseFromString(serialized_graph)

  19.        tf.import_graph_def(od_graph_def, name='')

模型使用coco数据集,可以检测与分割90个对象类别,所以下面需要把对应labelmap文件读进去,这个文件在

models\research\objectdetection\data

目录下,实现代码如下:

  1. label_map = label_map_util.load_labelmap(PATH_TO_LABELS)

  2. categories = label_map_util.convert_label_map_to_categories(label_map, max_num_classes=NUM_CLASSES, use_display_name=True)

  3. category_index = label_map_util.create_category_index(categories)

有了这个之后就需要从模型中取出如下几个tensor

  • num_detections 表示检测对象数目

  • detection_boxes 表示输出框BB

  • detection_scores 表示得分

  • detection_classes 表示对象类别索引

  • detection_masks 表示mask分割

然后在会话中运行这几个tensor即可,代码实现如下:

  1. def run_inference_for_single_image(image, graph):

  2.    with graph.as_default():

  3.        with tf.Session() as sess:

  4.            # Get handles to input and output tensors

  5.            ops = tf.get_default_graph().get_operations()

  6.            all_tensor_names = {output.name for op in ops for output in op.outputs}

  7.            tensor_dict = {}

  8.            for key in ['num_detections', 'detection_boxes', 'detection_scores', 'detection_classes', 'detection_masks']:

  9.                tensor_name = key + ':0'

  10.                if tensor_name in all_tensor_names:

  11.                    tensor_dict[key] = tf.get_default_graph().get_tensor_by_name(tensor_name)

  12.            if 'detection_masks' in tensor_dict:

  13.                # The following processing is only for single image

  14.                detection_boxes = tf.squeeze(tensor_dict['detection_boxes'], [0])

  15.                detection_masks = tf.squeeze(tensor_dict['detection_masks'], [0])

  16.                # Reframe is required to translate mask from box coordinates to image coordinates and fit the image size.

  17.                real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)

  18.                detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])

  19.                detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])

  20.                detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(

  21.                    detection_masks, detection_boxes, image.shape[0], image.shape[1])

  22.                detection_masks_reframed = tf.cast(

  23.                    tf.greater(detection_masks_reframed, 0.5), tf.uint8)

  24.                # Follow the convention by adding back the batch dimension

  25.                tensor_dict['detection_masks'] = tf.expand_dims(

  26.                    detection_masks_reframed, 0)

  27.            image_tensor = tf.get_default_graph().get_tensor_by_name('image_tensor:0')

  28.            # Run inference

  29.            output_dict = sess.run(tensor_dict,

  30.                                 feed_dict={image_tensor: np.expand_dims(image, 0)})

  31.            # all outputs are float32 numpy arrays, so convert types as appropriate

  32.            output_dict['num_detections'] = int(output_dict['num_detections'][0])

  33.            output_dict['detection_classes'] = output_dict[

  34.              'detection_classes'][0].astype(np.uint8)

  35.            output_dict['detection_boxes'] = output_dict['detection_boxes'][0]

  36.            output_dict['detection_scores'] = output_dict['detection_scores'][0]

  37.            if 'detection_masks' in output_dict:

  38.                output_dict['detection_masks'] = output_dict['detection_masks'][0]

  39.        return output_dict

下面就是通过opencv来读取一张彩色测试图像,然后调用模型进行检测与对象分割,代码实现如下:

  1. image = cv2.imread("D:/apple.jpg");

  2. # image = cv2.imread("D:/tensorflow/models/research/object_detection/test_images/image2.jpg");

  3. cv2.imshow("input image", image)

  4. print(image.shape)

  5. # Actual detection.

  6. output_dict = run_inference_for_single_image(image, detection_graph)

  7. # Visualization of the results of a detection.

  8. vis_util.visualize_boxes_and_labels_on_image_array(

  9.    image,

  10.    output_dict['detection_boxes'],

  11.    output_dict['detection_classes'],

  12.    output_dict['detection_scores'],

  13.    category_index,

  14.    instance_masks=output_dict.get('detection_masks'),

  15.    use_normalized_coordinates=True,

  16.    line_thickness=8)

原图如下:

检测运行结果如下:

带mask分割效果如下:

官方测试图像运行结果:


【推荐阅读】

OpenCV Gabor滤波器实现纹理提取与缺陷分析

OpenCV中如何获得物体的主要方向

tensorflow中实现神经网络训练手写数字数据集mnist

新课程发布 - 《tensorflow零基础入门视频教程》

tensorflow中实现神经网络训练手写数字数据集mnist

Windows系统如何安装Tensorflow Object Detection API

使用Tensorflow Object Detection API实现对象检测


寇可为,我复亦为 寇可往,我复亦往


关注【OpenCV学堂】

长按或者扫码二维码即可关注


    您可能也对以下帖子感兴趣

    文章有问题?点此查看未经处理的缓存